{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KQALG9h23b0R"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "both",
        "id": "U34SJW0W3dg_"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VIX1XZHJ3gFo"
      },
      "source": [
        "# TensorFlow NumPy: Distributed Image Classification Tutorial"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f7NApJ7R3ndN"
      },
      "source": [
        "## Overview\n",
        "\n",
        "TensorFlow implements a subset of the [NumPy API](https://numpy.org/doc/1.16), available as `tf.experimental.numpy`. This allows running NumPy code, accelerated by TensorFlow together with access to all of TensorFlow's APIs. Please see [TensorFlow NumPy Guide](https://www.tensorflow.org/guide/tf_numpy) to get started.\n",
        "\n",
        "Here you will learn how to build a deep model for an image classification task by using TensorFlow Numpy APIs. For using higher level `tf.keras` APIs, see the following [tutorial](https://www.tensorflow.org/tutorials/quickstart/beginner)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IYDdfih63rSG"
      },
      "source": [
        "## Setup\n",
        "\n",
        "tf.experimental.numpy will be available in the stable branch starting from TensorFlow 2.4. For now, it is available in  nightly."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "3IlLM-YlTMv5"
      },
      "outputs": [],
      "source": [
        "!pip install --quiet --upgrade tf-nightly\n",
        "!pip install --quiet --upgrade tensorflow-datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "U13hRXHKTcsE"
      },
      "outputs": [],
      "source": [
        "import collections\n",
        "import functools\n",
        "import matplotlib.pyplot as plt\n",
        "import os\n",
        "import tempfile\n",
        "import tensorflow as tf\n",
        "import tensorflow.experimental.numpy as tnp\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "gpus = tf.config.list_physical_devices('GPU')\n",
        "if gpus:\n",
        "  tf.config.set_logical_device_configuration(gpus[0], [\n",
        "      tf.config.LogicalDeviceConfiguration(memory_limit=128),\n",
        "      tf.config.LogicalDeviceConfiguration(memory_limit=128)])\n",
        "  devices = tf.config.list_logical_devices('GPU')\n",
        "else:\n",
        "  cpus = tf.config.list_physical_devices('CPU')\n",
        "  tf.config.set_logical_device_configuration(cpus[0], [\n",
        "      tf.config.LogicalDeviceConfiguration(),\n",
        "      tf.config.LogicalDeviceConfiguration()])\n",
        "  devices = tf.config.list_logical_devices('CPU')\n",
        "\n",
        "print(\"Using following virtual devices\", devices)"
      ]
    },
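    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick, optional sanity check (a minimal sketch, not required for the rest of the tutorial), note that `tnp` ND arrays and `tf.Tensor`s interoperate directly:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# ND arrays and tf.Tensors can be mixed freely in tnp and tf APIs.\n",
        "x = tnp.ones([2, 3], dtype=tnp.float32)  # an ND array\n",
        "y = tf.constant(2.0)                      # a tf.Tensor\n",
        "print(tnp.sum(x * y))                     # mixed arithmetic works"
      ]
    },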
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "AxNuZSqZKcdM"
      },
      "source": [
        "## Mnist dataset\n",
        "\n",
        "Mnist contains 28 * 28 images of digits from 0 to 9. The task is to classify the images as these 10 possible classes.\n",
        "\n",
        "Below, load the dataset and examine a few samples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yKf9Tm5OjwGK"
      },
      "outputs": [],
      "source": [
        "NUM_CLASSES = 10\n",
        "BATCH_SIZE = 64\n",
        "INPUT_SIZE = 28 * 28\n",
        "\n",
        "def process_data(data_dict):\n",
        "  images = tnp.asarray(data_dict['image']) / 255.0\n",
        "  images = images.reshape(-1, INPUT_SIZE).astype(tnp.float32)\n",
        "  labels = tnp.asarray(data_dict['label'])\n",
        "  labels = tnp.eye(NUM_CLASSES, dtype=tnp.float32)[labels]\n",
        "  return images, labels\n",
        "\n",
        "with tf.device(\"CPU:0\"):\n",
        "  train_dataset = tfds.load('mnist', split='train', shuffle_files=True, \n",
        "                            batch_size=BATCH_SIZE).map(process_data)\n",
        "  test_dataset = tfds.load('mnist', split='test', shuffle_files=True, \n",
        "                          batch_size=-1)\n",
        "  x_test, y_test = process_data(test_dataset)\n",
        "\n",
        "  # Plots some examples.\n",
        "  images, labels = next(iter(train_dataset.take(1)))\n",
        "  _, axes = plt.subplots(1, 8, figsize=(12, 96))\n",
        "  for i, ax in enumerate(axes):\n",
        "    ax.imshow(images[i].reshape(28, 28), cmap='gray')\n",
        "    ax.axis(\"off\")\n",
        "    ax.set_title(\"Label: %d\" % int(tnp.argmax(labels[i])))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZDJQp4i00qaJ"
      },
      "source": [
        "## Define layers and model\n",
        "\n",
        "Here, you will implement a multi-layer perceptron model that trains on the MNIST data. First, define a `Dense` class which applies a linear transform followed by a \"relu\" non-linearity."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "44yzAmBFreyg"
      },
      "outputs": [],
      "source": [
        "class Dense(tf.Module):\n",
        "\n",
        "  def __init__(self, units, use_relu=True):\n",
        "    self.wt = None\n",
        "    self.bias = None\n",
        "    self._use_relu = use_relu\n",
        "    self._built = False\n",
        "    self._units = units\n",
        "\n",
        "  def __call__(self, inputs):\n",
        "    if not self._built:\n",
        "      self._build(inputs.shape)\n",
        "    x = tnp.add(tnp.matmul(inputs, self.wt), self.bias)\n",
        "    if self._use_relu:\n",
        "      return tnp.maximum(x, 0.)\n",
        "    else:\n",
        "      return x\n",
        "\n",
        "  @property\n",
        "  def params(self):\n",
        "    assert self._built\n",
        "    return [self.wt, self.bias]\n",
        "\n",
        "  def _build(self, input_shape):\n",
        "    size = input_shape[1]\n",
        "    stddev = 1 / tnp.sqrt(size)\n",
        "    # Note that model parameters are `tf.Variable` since they requires\n",
        "    # mutation, which is currently unsupported by TensorFlow NumPy.\n",
        "    # Also note interoperation with TensorFlow APIs below.\n",
        "    self.wt = tf.Variable(\n",
        "        tf.random.truncated_normal(\n",
        "            [size, self._units], stddev=stddev, dtype=tf.float32))\n",
        "    self.bias = tf.Variable(tf.zeros([self._units], dtype=tf.float32))\n",
        "    self._built = True"
      ]
    },
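    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal, optional sketch (the shapes below are illustrative), a `Dense` layer builds its parameters lazily on the first call and maps inputs of shape `(batch, in_features)` to `(batch, units)`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Optional shape check: parameters are created on the first call.\n",
        "demo_layer = Dense(4)\n",
        "demo_out = demo_layer(tnp.ones((2, 8), dtype=tnp.float32))\n",
        "print(demo_out.shape)          # (2, 4)\n",
        "print(len(demo_layer.params))  # two parameters: wt and bias"
      ]
    },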
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wfKpg3adUCy9"
      },
      "source": [
        "Next, create a `Model` object that applies two non-linear `Dense` transforms,\n",
        "followed by a linear transform."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NdrdxKB7SenC"
      },
      "outputs": [],
      "source": [
        "class Model(tf.Module):\n",
        "  \"\"\"A  three layer neural network.\"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    self.layer1 = Dense(128)\n",
        "    self.layer2 = Dense(32)\n",
        "    self.layer3 = Dense(NUM_CLASSES, use_relu=False)\n",
        "\n",
        "  def __call__(self, inputs):\n",
        "    x = self.layer1(inputs)\n",
        "    x = self.layer2(x)\n",
        "    return self.layer3(x)\n",
        "\n",
        "  @property\n",
        "  def params(self):\n",
        "    return self.layer1.params + self.layer2.params + self.layer3.params"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hoxh5Z7E_9Pv"
      },
      "source": [
        "## Training and evaluation\n",
        "\n",
        "Checkout the following methods for performing training and evaluation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hOxqjE7rZPdr"
      },
      "outputs": [],
      "source": [
        "def forward(model, inputs, labels):\n",
        "  \"\"\"Computes prediction and loss.\"\"\"\n",
        "  logits = model(inputs)\n",
        "  # TensorFlow's loss function has numerically stable implementation of forward\n",
        "  # pass and gradients. So we prefer that here.\n",
        "  loss = tf.nn.softmax_cross_entropy_with_logits(labels, logits)\n",
        "  mean_loss = tnp.mean(loss)\n",
        "  return logits, mean_loss\n",
        "\n",
        "def compute_gradients(model, inputs, labels):\n",
        "  \"\"\"Computes gradients of loss based on `labels` and prediction on `inputs`.\"\"\"\n",
        "  with tf.GradientTape() as tape:\n",
        "    tape.watch(inputs)\n",
        "    _, loss = forward(model, inputs, labels)\n",
        "  gradients = tape.gradient(loss, model.params)\n",
        "  return gradients\n",
        "\n",
        "def compute_sgd_updates(gradients, learning_rate):\n",
        "  \"\"\"Computes parameter updates based on SGD update rule.\"\"\"\n",
        "  return [-learning_rate * grad for grad in gradients]\n",
        "\n",
        "def apply_updates(model, updates):\n",
        "  \"\"\"Applies `update` to `model.params`.\"\"\"\n",
        "  for param, update in zip(model.params, updates):\n",
        "    param.assign_add(update)\n",
        "\n",
        "def evaluate(model, images, labels):\n",
        "  \"\"\"Evaluates accuracy for `model`'s predictions.\"\"\"\n",
        "  prediction = model(images)\n",
        "  predicted_class = tnp.argmax(prediction, axis=-1)\n",
        "  actual_class = tnp.argmax(labels, axis=-1)\n",
        "  return float(tnp.mean(predicted_class == actual_class))"
      ]
    },
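    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As an optional sanity check (not required for training), an untrained model should score near chance, i.e. about 0.1 for 10 classes:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# An untrained model should be correct roughly 1 time in 10.\n",
        "print(evaluate(Model(), x_test, y_test))"
      ]
    },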
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8t70b5d6XCs7"
      },
      "source": [
        "### Single GPU training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HrhS_M6kALeP"
      },
      "outputs": [],
      "source": [
        "NUM_EPOCHS = 10\n",
        "\n",
        "@tf.function\n",
        "def train_step(model, input, labels, learning_rate):\n",
        "  gradients = compute_gradients(model, input, labels)\n",
        "  updates = compute_sgd_updates(gradients, learning_rate)\n",
        "  apply_updates(model, updates)\n",
        "\n",
        "# Creates and build a model.\n",
        "model = Model()\n",
        "\n",
        "accuracies = []\n",
        "for _ in range(NUM_EPOCHS):\n",
        "  for inputs, labels in train_dataset:\n",
        "    train_step(model, inputs, labels, learning_rate=0.1)\n",
        "  accuracies.append(evaluate(model, x_test, y_test))\n",
        "\n",
        "def plot_accuracies(accuracies):\n",
        "  plt.plot(accuracies)\n",
        "  plt.xlabel(\"epoch\")\n",
        "  plt.ylabel(\"accuracy\")\n",
        "  plt.title(\"Eval accuracy vs epoch\")\n",
        "\n",
        "plot_accuracies(accuracies)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Dw7RwQmKcYK9"
      },
      "source": [
        "#### Saving the models to disk"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rmk2xLQLcXkl"
      },
      "outputs": [],
      "source": [
        "# A temporary directory to save our models into.\n",
        "dir = tempfile.TemporaryDirectory()\n",
        "\n",
        "# We take our model, and create a wrapper for it.\n",
        "class SaveableModel(Model):\n",
        "  @tf.function\n",
        "  def __call__(self, inputs):\n",
        "    return super().__call__(inputs)\n",
        "\n",
        "saveable_model = SaveableModel()\n",
        "\n",
        "# This saves a concrete function that we care about.\n",
        "outputs = saveable_model(x_test)\n",
        "\n",
        "# This saves the model to disk.\n",
        "tf.saved_model.save(saveable_model, dir.name)\n",
        "\n",
        "loaded = tf.saved_model.load(dir.name)\n",
        "outputs_loaded = loaded(x_test)\n",
        "\n",
        "# Ensure that the loaded model preserves the weights\n",
        "# of the saved model.\n",
        "assert tnp.allclose(outputs, outputs_loaded)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ak_hCOkGXXfl"
      },
      "source": [
        "### Multi GPU runs\n",
        "\n",
        "Next, run mirrored training on multiple GPUs. Note that the GPUs used here are virtual and map to the same physical GPU.\n",
        "\n",
        "First, define a few utilities to run replicated computation and reductions."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ujbeT5p6Xm7k"
      },
      "source": [
        "#### Distribution primitives\n",
        "\n",
        "Checkout primitives below for function replication and distributed reduction."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MZ6hivj-ZIRo"
      },
      "outputs": [],
      "source": [
        "import threading\n",
        "import queue\n",
        "\n",
        "# Note that this code currently relies on dispatching operations from python\n",
        "# threads.\n",
        "class ReplicatedFunction(object):\n",
        "  \"\"\"Creates a callable that will run `fn` on each device in `devices`.\"\"\"\n",
        "\n",
        "  def __init__(self, fn, devices, **kw_args):\n",
        "    self._shutdown = False\n",
        "\n",
        "    def _replica_fn(device, input_queue, output_queue):\n",
        "      while not self._shutdown:\n",
        "        inputs = input_queue.get()\n",
        "        with tf.device(device):\n",
        "          output_queue.put(fn(*inputs, **kw_args))\n",
        "\n",
        "    self.threads = []\n",
        "    self.input_queues = [queue.Queue() for _ in devices]\n",
        "    self.output_queues = [queue.Queue() for _ in devices]\n",
        "    for i, device in enumerate(devices):\n",
        "      thread = threading.Thread(\n",
        "          target=_replica_fn,\n",
        "          args=(device, self.input_queues[i], self.output_queues[i]))\n",
        "      thread.start()\n",
        "      self.threads.append(thread)\n",
        "\n",
        "  def __call__(self, *inputs):\n",
        "    all_inputs = zip(*inputs)\n",
        "    for input_queue, replica_input, in zip(self.input_queues, all_inputs):\n",
        "      input_queue.put(replica_input)\n",
        "    return [q.get() for q in self.output_queues]\n",
        "\n",
        "  def __del__(self):\n",
        "    self._shutdown = True\n",
        "    for t in self.threads:\n",
        "      t.join(3)\n",
        "    self.threads = None\n",
        "\n",
        "def collective_mean(inputs, num_devices):\n",
        "  \"\"\"Performs collective mean reduction on inputs.\"\"\"\n",
        "  outputs = []\n",
        "  for instance_key, inp in enumerate(inputs):\n",
        "    outputs.append(tnp.asarray(\n",
        "      tf.raw_ops.CollectiveReduce(\n",
        "          input=inp, group_size=num_devices, group_key=0,\n",
        "          instance_key=instance_key, merge_op='Add', final_op='Div',\n",
        "          subdiv_offsets=[])))\n",
        "  return outputs"
      ]
    },
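    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a minimal sketch of the replication primitive (the `double` function and its inputs below are illustrative), replicate a trivial function across the virtual devices:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Runs `double` once per device, given one input per replica.\n",
        "def double(x):\n",
        "  return x * 2\n",
        "\n",
        "replicated_double = ReplicatedFunction(double, devices)\n",
        "per_replica_inputs = [tnp.asarray(float(i)) for i in range(len(devices))]\n",
        "print(replicated_double(per_replica_inputs))  # one doubled output per replica"
      ]
    },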
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1ZiN1rpJYHLu"
      },
      "source": [
        "#### Distributed training "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "A6ZHYmLapunm"
      },
      "outputs": [],
      "source": [
        "# This is similar to `train_step` except for an extra collective reduction of\n",
        "# gradients\n",
        "@tf.function\n",
        "def replica_step(model, inputs, labels,\n",
        "                 learning_rate=None, num_devices=None):\n",
        "  gradients = compute_gradients(model, inputs, labels)\n",
        "  # Note that each replica performs a reduction to compute mean of gradients.\n",
        "  reduced_gradients = collective_mean(gradients, num_devices)\n",
        "  updates = compute_sgd_updates(reduced_gradients, learning_rate)\n",
        "  apply_updates(model, updates)\n",
        "\n",
        "models = [Model() for _ in devices]\n",
        "\n",
        "# The code below builds all the model objects and copies model parameters from\n",
        "# the first model to all the replicas.\n",
        "def init_model(model):\n",
        "  model(tnp.zeros((1, INPUT_SIZE), dtype=tnp.float32))\n",
        "  if model != models[0]:\n",
        "    # Copy the first models weights into the other models.\n",
        "    for p1, p2 in zip(model.params, models[0].params):\n",
        "      p1.assign(p2)\n",
        "\n",
        "with tf.device(devices[0]):\n",
        "  init_model(models[0])\n",
        "# Replicate and run the parameter initialization.\n",
        "ReplicatedFunction(init_model, devices[1:])(models[1:])\n",
        "\n",
        "# Replicate the training step\n",
        "replicated_step = ReplicatedFunction(\n",
        "    replica_step, devices, learning_rate=0.1, num_devices=len(devices))\n",
        "\n",
        "accuracies = []\n",
        "print(\"Running distributed training on devices: %s\" % devices)\n",
        "for _ in range(NUM_EPOCHS):\n",
        "  for inputs, labels in train_dataset:\n",
        "    replicated_step(models,\n",
        "                    tnp.split(inputs, len(devices)),\n",
        "                    tnp.split(labels, len(devices)))\n",
        "  accuracies.append(evaluate(models[0], x_test, y_test))\n",
        "\n",
        "plot_accuracies(accuracies)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [
        "KQALG9h23b0R",
        "f7NApJ7R3ndN"
      ],
      "name": "TensorFlow Numpy: Distributed Image Classification",
      "private_outputs": true,
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
